from read_dataset import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import math



def real_data(l, seed):
    """Load a dataset, shuffle its rows reproducibly, and return torch tensors.

    Args:
        l: Dataset identifier, forwarded unchanged to ``read_dataset``.
        seed: Seed for NumPy's global RNG; fixes the row permutation.

    Returns:
        ``(n, d, X, Y)`` where ``n`` and ``d`` are the values reported by
        ``read_dataset``, ``X`` is a ``torch.Tensor`` of shape ``(n, d)`` and
        ``Y`` is a ``torch.Tensor`` of shape ``(n,)``.

    Raises:
        AssertionError: If the shuffled arrays do not have the shapes
            implied by ``n`` and ``d``.
    """
    np.random.seed(seed)
    n, d, X, Y = read_dataset(l)
    # Append the targets as a last column so that features and targets are
    # permuted together (np.random.permutation shuffles along the first axis).
    XY = np.hstack([X, Y.reshape(-1, 1)])
    XY = np.random.permutation(XY)
    X = XY[:, :-1]
    Y = XY[:, -1]
    # The original second check was a discarded expression after ';' and
    # compared against (n, 1); slicing the last column yields a 1-D array.
    assert X.shape == (n, d)
    assert Y.shape == (n,)
    return n, d, torch.Tensor(X), torch.Tensor(Y)


def obj(x, X, Y, n, lambd):
    """Evaluate the ridge objective ``||X x - Y||^2 / n + lambd * ||x||^2``.

    Args:
        x: Parameter tensor.
        X: Design matrix multiplied with ``x``.
        Y: Target tensor subtracted from ``X x``.
        n: Normaliser applied to the squared residual norm.
        lambd: Weight of the squared-norm penalty on ``x``.

    Returns:
        The objective value as a tensor.
    """
    residual = torch.matmul(X, x) - Y
    data_fit = torch.linalg.norm(residual)**2/n
    penalty = lambd*torch.linalg.norm(x)**2
    return data_fit + penalty

def gradient(x, X, Y, n, lambd):
    """Return the gradient ``2 X^T (X x - Y) / n + 2 lambd x``.

    This is the derivative with respect to ``x`` of
    ``||X x - Y||^2 / n + lambd * ||x||^2``.
    """
    residual = torch.matmul(X, x) - Y
    data_term = 2*torch.matmul(X.t(), residual)/n
    return data_term + 2*lambd*x

def hessian(X, n, lambd, d):
    """Return the exact Hessian ``2 X^T X / n + 2 lambd I_d``.

    Args:
        X: Design matrix (2-D tensor, required by ``torch.mm``).
        n: Normaliser applied to ``X^T X``.
        lambd: Ridge weight; contributes ``2 * lambd`` on the diagonal.
        d: Size of the identity matrix that is added.
    """
    gram = torch.mm(X.t(), X)
    ridge = 2*lambd*torch.eye(d)
    return 2*gram/n + ridge

def sh(m, n, d, X, lambd, seed):
    """Return a Gaussian-sketched Hessian ``2 (S X)^T (S X) / n + 2 lambd I_d``.

    ``S`` has ``m`` rows, each drawn from ``N(0, I_n / m)`` after seeding
    torch's global RNG with ``seed``, so the result is deterministic for a
    given ``seed``.

    Args:
        m: Number of sketch rows.
        n: Number of columns of ``S``; also the normaliser of the data term.
        d: Size of the identity matrix in the ridge term.
        X: Matrix that is sketched from the left by ``S``.
        lambd: Ridge weight.
        seed: Seed passed to ``torch.random.manual_seed``.
    """
    torch.random.manual_seed(seed)

    row_dist = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(n), (1/m)*torch.eye(n)
    )
    S = row_dist.sample((m,))
    assert S.shape == (m, n)
    sketched = torch.mm(S, X)
    ridge = 2*lambd*torch.eye(d)
    return 2*torch.mm(sketched.t(), sketched)/n + ridge

def sh_shrink(X, n, lambd, d, m, seed):
    """Return the sketched Hessian with its data term rescaled by ``1 / (1 - dl/m)``.

    ``dl`` is ``trace(H (H + 2 lambd I)^{-1})`` with ``H = 2 X^T X / n``,
    computed from the unsketched data. The ridge part ``2 lambd I`` of the
    sketch returned by ``sh`` is removed before scaling and added back after.
    """
    ridge = 2*lambd*torch.eye(d)
    H_data = 2*torch.mm(X.t(), X)/n
    dl = torch.trace(torch.mm(H_data, torch.linalg.inv(H_data + ridge)))
    sketched_data = sh(m, n, d, X, lambd, seed) - ridge
    return (1/(1-dl/m))*sketched_data + ridge

def sh_shrink_s(X, n, lambd, d, m, seed):
    """Return the sketched Hessian rescaled using a ``dl`` estimated from the sketch.

    Works like ``sh_shrink`` but ``dl`` is computed from the sketched data
    term instead of from ``X^T X``. The sketch is deterministic in ``seed``,
    so it is built once and reused for both the ``dl`` estimate and the
    final rescaling (the original called ``sh`` twice with identical
    arguments).

    Args:
        X, n, lambd, d, m, seed: Forwarded to ``sh``; ``lambd`` and ``d`` also
            define the ridge term ``2 lambd I_d``.
    """
    ridge = 2*lambd*torch.eye(d)
    # Data part of the sketched Hessian, i.e. 2 (S X)^T (S X) / n.
    sketched_data = sh(m, n, d, X, lambd, seed) - ridge
    SX = sketched_data*n/2
    # NOTE(review): the extra 1/m factor is kept exactly as in the original
    # formula; confirm it is the intended scaling.
    H_hat = 2*(SX/m)/n
    dl = torch.trace(torch.mm(H_hat, torch.linalg.inv(H_hat + ridge)))
    return (1/(1-dl/m))*sketched_data + ridge


def sample_sketch(m, d, seed):
    """Draw an ``(m, d)`` Gaussian sketching matrix with rows from ``N(0, I_d / m)``.

    Torch's global RNG is reseeded with ``seed`` first, so the same
    ``(m, d, seed)`` always yields the same matrix.
    """
    torch.random.manual_seed(seed)
    row_dist = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(d), (1/m)*torch.eye(d)
    )
    sketch = row_dist.sample((m,))
    assert sketch.shape == (m, d)
    return sketch

def compute_sm(H_S, val):
    """Return ``trace((H_S - val * I)^{-1}) / k`` where ``k = H_S.shape[0]``.

    Args:
        H_S: Square matrix (callers pass the sketched Hessian ``S H S^T``).
        val: Scalar shift subtracted on the diagonal; callers pass negative
            values to add a positive regulariser.
    """
    size = H_S.shape[0]
    shifted = H_S - val*torch.eye(size)
    return torch.trace(torch.inverse(shifted))/size

def find_l_hat(H_S, l):
    """Find ``l_hat`` in ``[5l/12, l]`` with ``compute_sm(H_S, -l_hat) == 1/l`` by bisection.

    The search only runs when the interval brackets the target, i.e.
    ``compute_sm(H_S, -5l/12) >= 1/l >= compute_sm(H_S, -l)``. Otherwise a
    message is printed and the lower end ``5l/12`` is returned.

    Args:
        H_S: Square matrix passed to ``compute_sm``.
        l: Positive scalar defining both the search interval and the target.

    Returns:
        The midpoint at which the bisection stopped, or ``5*l/12`` if the
        bracket check fails.
    """
    target = 1/l
    lo, hi = 5*l/12, l
    if not (compute_sm(H_S, -lo) >= target and compute_sm(H_S, -hi) <= target):
        print('Error in assertion')
        return 5*l/12

    mid_point = (lo + hi)/2
    mid_point_val = compute_sm(H_S, -mid_point)
    # The original looped while the error was *below* the tolerance, so the
    # bisection never refined an unconverged midpoint. Iterate until the
    # tolerance is met instead. The cap guards against tolerances that the
    # floating-point precision of compute_sm cannot reach.
    for _ in range(100):
        if torch.abs(mid_point_val - target) < 1e-6:
            break
        if mid_point_val > target:
            # Value too large at mid_point: move the lower end up, as the
            # original update did.
            lo = mid_point
        else:
            hi = mid_point
        mid_point = (lo + hi)/2
        mid_point_val = compute_sm(H_S, -mid_point)
    return mid_point

def debias(H, l, d, seed, q):
    """
    our version of Hessian inverse estimation :
       step 1: compute sketch dimension m
       step 2: compute l_hat
       step 3: sketch and return

    Averages ``S^T (S H S^T + l_hat I)^{-1} S`` over ``q`` workers.

    Each worker uses its own deterministic seed ``seed*q + worker``. The
    original passed the same ``seed`` to every worker; because
    ``sample_sketch`` reseeds the RNG on each call, all ``q`` sketches were
    identical and the average was a no-op.

    Args:
        H: ``(d, d)`` matrix to be sketched.
        l: Regularisation level; thresholds are expressed relative to ``1/l``.
        d: Dimension of ``H``; upper bound for the sketch-size search.
        seed: Base seed from which per-worker seeds are derived.
        q: Number of workers to average over.

    Returns:
        The ``(d, d)`` average of the per-worker inverse estimates.
    """
    worker_sum = 0
    for worker in range(q):
        worker_seed = seed*q + worker
        # step 1: double m until the sketched matrix passes the threshold
        # test (or m reaches d).
        m = 2
        while m < d:
            S = sample_sketch(m, d, worker_seed)
            if compute_sm(S@H@S.T, -5*l/12) > 1/l:
                break
            m = 2*m
        # step 2: resample at the final m (the loop may have ended without
        # sampling at that size) and choose l_hat.
        S = sample_sketch(m, d, worker_seed)
        H_S = S@H@S.T
        if compute_sm(H_S, 0) <= 1/l:
            l_hat = 5*l/12
        else:
            l_hat = find_l_hat(H_S, l)
        # step 3
        worker_sum += S.T@torch.inverse(H_S + l_hat*torch.eye(S.shape[0]))@S
    return worker_sum/q


def nodebias(H, l, d, seed, q):
    """
    Hessian inverse estimation without the l_hat correction :
       step 1: compute sketch dimension m (same search as ``debias``)
       step 2: sketch with the unmodified regulariser l and return

    Averages ``S^T (S H S^T + l I)^{-1} S`` over ``q`` workers.

    Differences from the original implementation:
      * each worker uses its own seed ``seed*q + worker``; previously every
        worker reseeded with the same ``seed`` and produced the same sketch,
        so the average was a no-op;
      * ``S`` is resampled at the final ``m`` after the search, as ``debias``
        does. Previously ``S`` was left at the previous size when the loop
        ended without ``break`` and was undefined when ``d <= 2``.

    Args:
        H: ``(d, d)`` matrix to be sketched.
        l: Regularisation added to the sketched matrix.
        d: Dimension of ``H``; upper bound for the sketch-size search.
        seed: Base seed from which per-worker seeds are derived.
        q: Number of workers to average over.

    Returns:
        The ``(d, d)`` average of the per-worker inverse estimates.
    """
    worker_sum = 0
    for worker in range(q):
        worker_seed = seed*q + worker
        # step 1
        m = 2
        while m < d:
            S = sample_sketch(m, d, worker_seed)
            if compute_sm(S@H@S.T, -5*l/12) > 1/l:
                break
            m = 2*m
        # step 2: make sure S matches the final m.
        S = sample_sketch(m, d, worker_seed)
        worker_sum += S.T@torch.inverse(S@H@S.T + l*torch.eye(S.shape[0]))@S
    return worker_sum/q


def exact_dl(X, H, l, d, seed, q):
    """Average sketched inverse estimates using a sketch size derived from the exact ``dl``.

    ``dl = trace(H0 (H0 + l I)^{-1})`` with ``H0 = 2 X^T X / n`` and
    ``n = X.shape[0]``. The sketch size is ``m = ceil(dl) + 1`` and the
    regulariser is ``l_hat = l * (1 - dl/m)``.

    ``dl``, ``m`` and ``l_hat`` do not depend on the worker, so they are
    computed (and ``m`` is printed) once instead of ``q`` times. Each worker
    uses its own seed ``seed*q + worker``; the original reused ``seed`` for
    every worker, which made all ``q`` sketches identical.

    Args:
        X: Data matrix used to compute ``dl``.
        H: ``(d, d)`` matrix to be sketched.
        l: Regularisation level.
        d: Dimension of ``H``.
        seed: Base seed from which per-worker seeds are derived.
        q: Number of workers to average over.

    Returns:
        The ``(d, d)`` average of ``S^T (S H S^T + l_hat I)^{-1} S``.
    """
    # step 1
    n = X.shape[0]
    H_data = 2*torch.mm(X.t(), X)/n
    dl = torch.trace(torch.mm(H_data, torch.linalg.inv(H_data + l*torch.eye(d))))
    m = math.ceil(dl)+1
    print('exact m: ', m)
    l_hat = l*(1-dl/m)

    worker_sum = 0
    for worker in range(q):
        # step 2
        S = sample_sketch(m, d, seed*q + worker)
        # step 3
        worker_sum += S.T@torch.inverse(S@H@S.T + l_hat*torch.eye(S.shape[0]))@S
    return worker_sum/q
    
def line_search(x, del_x, dec,  X, Y, n, lambd) -> tuple[float,int]:
    """Backtracking line search along ``del_x`` starting from step size 1.

    The step ``t`` is halved while
    ``obj(x + t*del_x) > obj(x) + t * (dec / 4)``, for at most 10 halvings.

    Args:
        x: Current point.
        del_x: Search direction added to ``x``.
        dec: Slope term used in the sufficient-decrease test; it is scaled
            by 1/4 internally.
        X, Y, n, lambd: Forwarded to ``obj``.

    Returns:
        ``(t, halvings)``: the accepted step size and the number of times it
        was halved.
    """
    shrink = 1/2
    slope = (1/4)*dec
    baseline = obj(x, X, Y, n, lambd)
    t = 1
    halvings = 0
    trial = obj(x+t*del_x, X, Y, n, lambd)
    while (trial > baseline + t*slope).any() and halvings < 10:
        halvings += 1
        t = shrink*t
        trial = obj(x+t*del_x, X, Y, n, lambd)
    return t, halvings

def _newton_direction(g, X, n, lambd, m, d, seed, method):
    """Return ``H_approx^{-1} g`` for the Hessian approximation named by ``method``."""
    if method in ('debias', 'nodebias', 'e_debias'):
        # These estimators return an approximate *inverse* Hessian directly.
        H = 2*torch.mm(X.t(), X)/n
        l = 2*lambd
        q = 10
        if method == 'debias':
            h_inv = debias(H, l, d, seed, q)
        elif method == 'nodebias':
            h_inv = nodebias(H, l, d, seed, q)
        else:
            h_inv = exact_dl(X, H, l, d, seed, q)
        return h_inv@g

    if method == 'h':
        h = hessian(X, n, lambd, d)
    elif method == 'ihs':
        h = sh(m, n, d, X, lambd, seed)
    elif method == 'ihs_s':
        h = sh_shrink(X, n, lambd, d, m, seed)
    elif method == 'ihs_ss':
        h = sh_shrink_s(X, n, lambd, d, m, seed)
    else:
        raise ValueError(f'unknown method: {method!r}')
    # These builders return the (approximate) Hessian itself, so the Newton
    # direction requires a linear solve, not a matrix-vector product.
    return torch.linalg.solve(h, g)


def newton(x_init, X, Y, n, lambd, m, d, seed, method, newton_iter=50, newton_eps=1e-4):
    """Run (sketched) Newton iterations with backtracking line search on the ridge objective.

    Fixes relative to the original implementation:
      * for ``'h'``, ``'ihs'``, ``'ihs_s'`` and ``'ihs_ss'`` the step is now
        ``H^{-1} g`` (the original computed ``H @ g``);
      * ``'e_debias'`` is reachable (it used to fall into the first branch
        and fail on an undefined ``h``);
      * an unknown ``method`` raises ``ValueError`` instead of ``NameError``;
      * the line search receives the slope along the direction it actually
        searches (``-del_x``), i.e. ``-g^T del_x``.

    Args:
        x_init: Starting point.
        X, Y, n, lambd: Problem data, forwarded to ``obj`` / ``gradient``.
        m: Sketch size for the ``'ihs*'`` methods.
        d: Problem dimension.
        seed: Seed forwarded to the sketching routines.
        method: One of ``'h'``, ``'ihs'``, ``'ihs_s'``, ``'ihs_ss'``,
            ``'debias'``, ``'nodebias'``, ``'e_debias'``.
        newton_iter: Maximum number of iterations.
        newton_eps: Stop when ``|g^T del_x| / 2 <= newton_eps``.

    Returns:
        ``(record, x, i, stop_reason)`` where ``record`` is a list of
        ``(iteration, objective)`` pairs, ``x`` the final iterate, ``i`` the
        index of the last iteration (``-1`` if none ran) and ``stop_reason``
        is ``'converged'`` or ``'max_iter reached'``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    x = x_init
    stop_reason = 'max_iter reached'
    record = []
    i = -1  # keeps the return value defined when newton_iter == 0
    for i in range(newton_iter):
        g = gradient(x, X, Y, n, lambd)
        del_x = _newton_direction(g, X, n, lambd, m, d, seed, method)
        dec = g.t()@del_x
        record.append((i, obj(x, X, Y, n, lambd)))
        if torch.abs(dec/2) <= newton_eps:
            stop_reason = 'converged'
            break
        # The search direction is -del_x, so its directional derivative is -dec.
        t, _ = line_search(x, -del_x, -dec, X, Y, n, lambd)
        x = x-t*del_x
    print(stop_reason)
    return (record, x,  i,  stop_reason)

def plot(X, Y, n, lambd, m, d, seed):
    """Plot objective-vs-iteration curves for 'ihs' and 'ihs_s' and save them to 'ihs.pdf'.

    ``newton`` is also run once with method 'h'; its result is discarded.
    All runs start from the zero vector of length ``d``.
    """
    x_init = torch.zeros(d)
    records = [
        newton(x_init, X, Y, n, lambd, m, d, seed, method=name)[0]
        for name in ('ihs', 'ihs_s')
    ]
    newton(x_init, X, Y, n, lambd, m, d, seed, method='h')
    for record in records:
        # Column 0 holds the iteration index, column 1 the objective value.
        curve = torch.Tensor(record).cpu().numpy()
        plt.plot(curve[:, 0], curve[:, 1])
    plt.yscale('log')
    plt.legend(['IHS','IHS With Shrinkage'])
    plt.savefig('ihs.pdf')

def get_x_axis(data_set):
    """Build a common x grid from 0 to the largest final x value among the runs.

    Each run is a sequence of records whose first entry is the x value. The
    run whose last record has the largest x determines the grid: 2 points if
    that run has exactly 2 records, otherwise ``max(len(run), 100)`` points.
    With no runs the grid goes from 0 to -1 with 100 points.
    """
    furthest = -1
    run_length = 0
    for run in data_set:
        last_x = run[-1][0]
        if last_x > furthest:
            furthest, run_length = last_x, len(run)
    num_points = 2 if run_length == 2 else max(run_length, 100)
    return np.linspace(0, furthest, num_points)

def interpolate(data, axis, optimals):
    """Resample every run onto ``axis`` and return pointwise quantiles across runs.

    Args:
        data: Sequence of runs, each passed to ``get_numbers``.
        axis: Common x grid.
        optimals: Accepted for interface compatibility; not used.

    Returns:
        ``(median, q20, q80)``: the 0.5, 0.2 and 0.8 quantiles taken across
        runs at each grid point.
    """
    n_runs, n_points = len(data), len(axis)
    numbers = np.zeros((n_runs, n_points))
    for row, run in enumerate(data):
        numbers[row] = get_numbers(run, axis)
    assert numbers.shape == (n_runs, n_points)
    median, lower, upper = (
        np.quantile(numbers, level, axis=0) for level in (0.5, 0.2, 0.8)
    )
    return (median, lower, upper)

def get_numbers(data, axis, optimal=None):
    """Interpolate a run's relative gap ``|y - ref| / |ref|`` onto ``axis``.

    Args:
        data: Sequence of ``(x, y)`` records convertible by ``torch.Tensor``.
        axis: x positions at which to evaluate the gap.
        optimal: Optional reference tensor; when omitted, the ``y`` of the
            last record is used as the reference.

    Returns:
        NumPy array with the linearly interpolated gap at each ``axis`` value.
    """
    table = np.array(torch.Tensor(data))
    xs, ys = table[:, 0], table[:, 1]
    if optimal is None:
        reference = ys[-1]
    else:
        reference = np.array(torch.Tensor.cpu(optimal))
    gap = np.abs(ys - reference)/np.abs(reference)
    return np.interp(axis, xs, gap)

def plot_multi_realdata(data):
    """Compare 'debias' and 'nodebias' Newton runs on a real dataset and save the plot.

    For seeds 0..9 the dataset is loaded and shuffled with ``real_data``,
    both methods are run from the zero vector with ``lambd = 0.1``, and the
    runs are summarised with ``interpolate`` (median line, 0.2-0.8 quantile
    band). The figure is written to ``'uci_ridge-<data>.pdf'``.

    Seeds for which the 'debias' run hits the iteration limit are skipped
    entirely.

    Args:
        data: Dataset identifier, passed to ``real_data`` and used in the
            plot title and output file name.
    """
    lambd = 0.1
    newton_debias = []
    newton_nodebias = []
    x_axis = {}
    optimal = []
    for i in range(10):
        print(i)
        n,d, X,Y = real_data(data, i)
        dl = torch.trace(torch.mm(2*torch.mm(X.t(),X)/n,torch.linalg.inv(2*torch.mm(X.t(),X)/n + 2*lambd*torch.eye(d))))
        print('dl=',dl)
        x_init = torch.zeros(d,)
        m = 100
        max_iter=200
        k3 = newton(x_init, X, Y, n, lambd, m, d, newton_iter=max_iter, newton_eps=1e-6, seed=i, method='debias')
        if k3[-1]=='max_iter reached':
            continue
        k5 = newton(x_init, X, Y, n, lambd, m, d, newton_iter=max_iter, newton_eps=1e-6, seed=i, method='nodebias')
        # NOTE(review): `optimal` is only filled when the nodebias run did
        # NOT converge, and `interpolate` ignores it; kept as in the
        # original - confirm the intended bookkeeping.
        if k5[-1]=='max_iter reached':
            optimal.append(k3[0][-1])
        newton_debias.append(k3[0])
        newton_nodebias.append(k5[0])
    x_axis['debias'] = get_x_axis(newton_debias)
    x_axis['nodebias'] = get_x_axis(newton_nodebias)
    plot_data = {}
    plot_data['debias'] = interpolate(newton_debias, x_axis['debias'], optimal)
    plot_data['nodebias'] = interpolate(newton_nodebias, x_axis['nodebias'], optimal)
    # plt.subplots() creates the figure that is drawn on and saved. The
    # original also opened an unused 100x100-inch figure right before it,
    # which was never drawn on or closed; that call is removed.
    fig, ax = plt.subplots()
    clrs = sns.color_palette("husl", 5)
    clrs[0] = (0.9677975592919913, 0.44127456009157356, 0.5358103155058701)
    clrs[1] = (0.3126890019504329, 0.6928754610296064, 0.1923704830330379)
    clrs[2] = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
    ax.plot(x_axis['nodebias'], plot_data['nodebias'][0], label='No Debiasing', c=clrs[0])
    ax.fill_between(x_axis['nodebias'], plot_data['nodebias'][1], plot_data['nodebias'][2],alpha=0.3, facecolor=clrs[0])
    ax.plot(x_axis['debias'], plot_data['debias'][0], label='Debiasing (Ours)', c=clrs[1])
    ax.fill_between(x_axis['debias'], plot_data['debias'][1], plot_data['debias'][2],alpha=0.3, facecolor=clrs[1])
    ax.legend(fontsize=18, loc="upper right")
    ax.set_yscale('log')
    plt.xlabel('Newton Steps', fontsize=20)
    plt.title(data+' (ridge regression)', fontsize=20)
    plt.ylabel('Log Optimality Gap', fontsize=20)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    plt.tight_layout()
    plt.savefig('uci_ridge-%s.pdf'%(data))


if __name__ == '__main__':
    # The dataset identifier is the single required command-line argument;
    # exit with a usage message instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <dataset>')
    plot_multi_realdata(sys.argv[1])

